from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


def _as_triple(value):
    """Return ``value`` as a 3-tuple.

    A scalar is repeated three times.  A tuple or list is converted to a
    tuple; a single-element sequence is broadcast the same way a scalar is.
    """
    if isinstance(value, (tuple, list)):
        value = tuple(value)
        return value * 3 if len(value) == 1 else value
    return (value,) * 3


@tensorrt_converter("torch.nn.functional.max_pool3d")
@tensorrt_converter("torch.max_pool3d")
def convert_max_pool3d(ctx):
    """Convert a 3D max-pooling call into a TensorRT N-d pooling layer.

    Reads the call arguments from ``ctx``, adds a ``MAX`` pooling layer to
    ``ctx.network`` and attaches the layer output to ``ctx.method_return`` as
    its ``_trt`` attribute.

    Args:
        ctx: Conversion context for the intercepted call.  This function uses
            ``ctx.network`` and ``ctx.method_return`` directly and reads the
            call arguments through ``get_arg``.
    """
    # parse args
    input = get_arg(ctx, "input", pos=0, default=None)
    kernel_size = get_arg(ctx, "kernel_size", pos=1, default=None)
    stride = get_arg(ctx, "stride", pos=2, default=None)
    padding = get_arg(ctx, "padding", pos=3, default=0)
    # NOTE(review): dilation is parsed but not applied to the pooling layer
    # below; a call with dilation != 1 is converted as if dilation were 1.
    dilation = get_arg(ctx, "dilation", pos=4, default=1)
    ceil_mode = get_arg(ctx, "ceil_mode", pos=5, default=False)

    # get input trt tensor (or create constant if it doesn't exist)
    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]

    output = ctx.method_return

    # get kernel size
    kernel_size = _as_triple(kernel_size)

    # get stride: PyTorch treats an omitted stride (None or an empty
    # sequence) as "same as kernel_size", so mirror that here instead of
    # handing (None, None, None) to TensorRT.
    stride_missing = stride is None or (
        isinstance(stride, (tuple, list)) and len(stride) == 0
    )
    stride = kernel_size if stride_missing else _as_triple(stride)

    # get padding
    padding = _as_triple(padding)

    layer = ctx.network.add_pooling_nd(
        input=input_trt, type=trt.PoolingType.MAX, window_size=kernel_size
    )

    layer.stride_nd = stride
    layer.padding_nd = padding

    if ceil_mode:
        # ceil_mode rounds the output size up instead of down
        layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP

    output._trt = layer.get_output(0)


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 4, 6, 7)])
@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 5, 7, 8)])
def test_MaxPool3d_without_ceil_mode():
    """Return a MaxPool3d module (kernel 3, stride 2, padding 1) with ceil_mode off."""
    pool_config = dict(kernel_size=3, stride=2, padding=1, ceil_mode=False)
    return torch.nn.MaxPool3d(**pool_config)


@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 4, 6, 7)])
@add_module_test(torch.float32, torch.device("cuda"), [(1, 3, 5, 7, 8)])
def test_MaxPool3d_with_ceil_mode():
    """Return a MaxPool3d module (kernel 3, stride 2, padding 1) with ceil_mode on."""
    pool_config = dict(kernel_size=3, stride=2, padding=1, ceil_mode=True)
    return torch.nn.MaxPool3d(**pool_config)
